# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Any, List, Optional, Union
# from args import *
import torch
import torch.nn as nn

from peft.tuners.tuners_utils import BaseTunerLayer
import math
from .dct_utils import idct_2d_impl, dct_2d_impl
    
class IdctWithIndexGrad(torch.autograd.Function):
    """Sparse 2-D inverse DCT that also back-propagates into the frequency locations.

    Forward: the continuous ``locations`` (row positions in ``locations[0]``, column
    positions in ``locations[1]``) are rounded to integers. They are passed, together with
    the ``updates`` coefficients and the target size ``row`` x ``col``, to ``idct_2d_impl``.

    Backward: the incoming gradient is taken to the frequency domain with ``dct_2d_impl``.
    * The coefficient gradient is that frequency-domain gradient read at the chosen positions.
    * The location gradient is a finite difference of it along each axis. Neighbour indices
      are clamped to the valid range, the difference is scaled by ``updates``, and the result
      is clipped to ``[-1, 1]``.
    """

    # DCT variant passed to both transforms. This is a class attribute, so it is shared by
    # every caller; LoCALayer.update_layer overwrites it.
    dct_mode = 'default'

    @staticmethod
    def forward(ctx, updates, locations, row, col):
        int_locations = locations.round().long()
        # The sizes are plain values, so keep them on ctx; only tensors go through
        # save_for_backward.
        ctx.n_rows = row
        ctx.n_cols = col
        ctx.save_for_backward(int_locations, updates)
        return idct_2d_impl(updates, int_locations, row, col, IdctWithIndexGrad.dct_mode)

    @staticmethod
    def backward(ctx, grad_output):
        int_locations, updates = ctx.saved_tensors
        rows, cols = int_locations[0, :], int_locations[1, :]
        wants_updates_grad = ctx.needs_input_grad[0]
        wants_locations_grad = ctx.needs_input_grad[1]

        # Gradient of the loss expressed in the frequency domain.
        freq_grad = dct_2d_impl(grad_output, IdctWithIndexGrad.dct_mode)

        updates_grad = freq_grad[rows, cols].view(-1) if wants_updates_grad else None

        locations_grad = None
        if wants_locations_grad:
            # Neighbouring indices, clamped into [0, size - 1] at the borders.
            row_prev = (rows - 1).clamp(min=0)
            row_next = (rows + 1).clamp(max=ctx.n_rows - 1)
            col_prev = (cols - 1).clamp(min=0)
            col_next = (cols + 1).clamp(max=ctx.n_cols - 1)

            row_grad = (0.5 * updates * (freq_grad[row_next, cols] - freq_grad[row_prev, cols])).clamp(min=-1, max=1)
            col_grad = (0.5 * updates * (freq_grad[rows, col_next] - freq_grad[rows, col_prev])).clamp(min=-1, max=1)
            locations_grad = torch.stack([row_grad, col_grad], dim=0)

        # ``row`` and ``col`` are sizes and receive no gradient.
        return updates_grad, locations_grad, None, None

class LoCALayer(BaseTunerLayer):
    """Mixin holding the per-adapter state of a LoCA layer.

    For every adapter name it keeps:

    * ``spectrum``: ``n_frequency`` spectral coefficients, zero-initialised;
    * ``spectrum_indices``: a ``(2, n_frequency)`` tensor of frequency locations,
      initialised uniformly in ``[0, 1)``;
    * ``n_frequency``, ``scale`` and a dropout module (``nn.Identity`` when the rate is 0).

    ``__init__`` assigns ``nn.Module`` containers, so the concrete subclass must run
    ``nn.Module.__init__`` before calling it (as ``Linear`` does).
    """

    # All names of layers that may contain (trainable) adapter weights.
    # A tuple, because BaseTunerLayer concatenates it with the tuple ``other_param_names``.
    adapter_layer_names = ("spectrum", "spectrum_indices")
    # All names of other parameters that may contain adapter-related parameters
    # other_param_names = ("rank", "dropout")

    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
        self.base_layer = base_layer
        self.n_frequency = {}
        self.scale = {}
        self.spectrum = nn.ParameterDict({})
        self.spectrum_indices = nn.ParameterDict({})
        # A ModuleDict (not a plain dict) registers the dropout modules as submodules,
        # so .train()/.eval()/.to() reach them.
        self.loca_dropout_layer = nn.ModuleDict({})
        # Filled in by update_layer.
        self.loca_dct_mode = None
        # Mark the weight as unmerged
        self._disable_adapters = False
        self.merged_adapters = []
        self.kwargs = kwargs

        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            in_features, out_features = base_layer.in_features, base_layer.out_features
        elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
            # QuantLinear
            in_features, out_features = base_layer.infeatures, base_layer.outfeatures
        elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
            # Megatron ColumnParallelLinear,RowParallelLinear
            in_features, out_features = base_layer.input_size, base_layer.output_size
        else:
            raise ValueError(f"Unsupported layer type {type(base_layer)}")

        self.in_features = in_features
        self.out_features = out_features

        # Counts training forward calls; drives the coefficient/location schedule.
        self.global_iter = 0
        self.learn_location_iter = None
        self.sparse_idct2d_with_index_gradient = IdctWithIndexGrad.apply

    def update_layer(self, adapter_name, n_frequency, scale, loca_dropout, learn_location_iter, loca_dct_mode, init_loca_weights=None):
        """Create (or replace) the LoCA parameters for ``adapter_name``.

        Args:
            adapter_name: Key under which the adapter's tensors are stored.
            n_frequency: Number of spectral coefficients; must be positive.
            scale: Multiplier applied to the reconstructed weight update.
            loca_dropout: Dropout rate on the adapter input; ``0`` means ``nn.Identity``.
            learn_location_iter: Stored per layer (not per adapter); the subclass uses it to
                end the phase in which locations are trained.
            loca_dct_mode: DCT variant. Stored per layer and also written to the class-level
                ``IdctWithIndexGrad.dct_mode``, which is shared by all layers.
            init_loca_weights: Currently unused here; ``spectrum`` is always zero-initialised.

        Raises:
            ValueError: if ``n_frequency`` is not positive.
        """
        if n_frequency <= 0:
            raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
        self.n_frequency[adapter_name] = n_frequency
        self.scale[adapter_name] = scale
        self.learn_location_iter = learn_location_iter
        self.loca_dct_mode = loca_dct_mode
        IdctWithIndexGrad.dct_mode = self.loca_dct_mode

        self.spectrum_indices[adapter_name] = nn.Parameter(torch.zeros(2, n_frequency))
        self.spectrum[adapter_name] = nn.Parameter(torch.zeros(n_frequency))
        nn.init.uniform_(self.spectrum_indices[adapter_name], 0, 1)
        # Zero coefficients => zero weight update at initialisation.
        nn.init.zeros_(self.spectrum[adapter_name])

        if loca_dropout > 0.0:
            self.loca_dropout_layer[adapter_name] = nn.Dropout(p=loca_dropout)
        else:
            self.loca_dropout_layer[adapter_name] = nn.Identity()

        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)
        self.set_adapter(self.active_adapters)

    def reset_loca_parameters(self, adapter_name, init_loca_weights):
        """Zero the spectral coefficients of ``adapter_name`` unless ``init_loca_weights`` is False."""
        if init_loca_weights is False:
            return
        if adapter_name in self.spectrum.keys():
            nn.init.zeros_(self.spectrum[adapter_name])

class Linear(nn.Module, LoCALayer):
    """LoCA adapter around a dense (``nn.Linear``-like) base layer.

    The adapter's weight update comes from ``spectrum_to_para``, which uses the adapter's
    spectral coefficients and frequency locations, multiplied by ``scale``.

    * While training, the update is applied on the fly: ``base(x) + dropout(x) @ delta_w.T``.
    * In eval mode it is merged into the base layer's weight.
    * When training resumes, it is unmerged again.
    """

    # loca implemented in a dense layer
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        n_frequency: int = 0,
        scale: float = 0.1,
        loca_dropout: float = 0.0,
        loca_dct_mode: str = 'default',
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        init_loca_weights: Union[bool, str] = True,
        learn_location_iter: int = 1000,
        **kwargs,
    ) -> None:
        # Note: update_layer rejects n_frequency <= 0, so the default of 0 must be overridden.
        super().__init__()
        LoCALayer.__init__(self, base_layer, **kwargs)
        self.fan_in_fan_out = fan_in_fan_out

        self._active_adapter = adapter_name
        self.update_layer(adapter_name, n_frequency, scale, loca_dropout, learn_location_iter, loca_dct_mode, init_loca_weights)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.

        Raises:
            ValueError: if ``safe_merge`` is set and the merged weights contain non-finite values.
        """
        if adapter_names is None:
            adapter_names = self.active_adapters

        if self.merged:
            # Report the adapters actually being merged now, not just the active ones.
            warnings.warn(
                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
                f"You are now additionally merging {','.join(adapter_names)}."
            )

        base_layer = self.get_base_layer()
        for active_adapter in adapter_names:
            if active_adapter not in self.spectrum:
                continue
            delta_w = self.get_delta_weight(active_adapter)
            if safe_merge:
                # Note that safe_merge will be slower than the normal merge
                # because of the copy operation.
                merged_weights = base_layer.weight.data.clone()
                merged_weights += delta_w

                if not torch.isfinite(merged_weights).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                base_layer.weight.data = merged_weights
            else:
                base_layer.weight.data += delta_w
            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.spectrum:
                self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)

    def spectrum_to_para(self, adapter):
        """Reconstruct the (unscaled) dense weight update of ``adapter`` from its spectrum.

        The stored locations are clamped to ``[0, 0.999]``. Row positions are then scaled
        to ``[0, out_features - 1]`` and column positions to ``[0, in_features - 1]``.

        * If the locations are trainable, the update goes through ``IdctWithIndexGrad`` so
          they receive gradients.
        * Otherwise they are rounded and ``idct_2d_impl`` is called directly.
        """
        spectral_para = self.spectrum[adapter]
        lo_clip = torch.clamp(self.spectrum_indices[adapter], min=0, max=0.999)
        lo_clip[0, :], lo_clip[1, :] = lo_clip[0, :] * (self.out_features - 1), lo_clip[1, :] * (self.in_features - 1)
        if self.spectrum_indices[adapter].requires_grad:
            return self.sparse_idct2d_with_index_gradient(spectral_para, lo_clip, self.out_features, self.in_features)
        else:
            return idct_2d_impl(spectral_para, lo_clip.round().long(), self.out_features, self.in_features, self.loca_dct_mode)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """
        Compute the delta weight for the given adapter.

        The result is ``spectrum_to_para(adapter) * scale[adapter]``, transposed when
        ``fan_in_fan_out`` is set, so it can be added directly to the base weight.

        Args:
            adapter (str):
                The name of the adapter for which the delta weight should be computed.
        """
        # No matmul happens here, so no fp32 round-trip is needed for CPU + fp16.
        # The old code wrote to a non-existent ``self.weight`` on that path.
        delta_w = self.spectrum_to_para(adapter) * self.scale[adapter]
        return delta_w.T if self.fan_in_fan_out else delta_w

    def _update_trainable_schedule(self) -> None:
        """Switch which LoCA tensors of the active adapters are trainable, then advance ``global_iter``.

        * Call 0 trains only the coefficients.
        * Before ``learn_location_iter``, a 30-call cycle applies: coefficients are switched on
          at offset 0 and locations at offset 10.
        * At ``learn_location_iter`` the coefficients are switched on (and locations off) for good.
        """
        iteration = self.global_iter
        train_spectrum = None
        if iteration == 0:
            train_spectrum = True
        elif iteration < self.learn_location_iter:
            cycle_position = iteration % 30
            if cycle_position == 0:
                train_spectrum = True
            elif cycle_position == 10:
                train_spectrum = False
        elif iteration == self.learn_location_iter:
            train_spectrum = True

        if train_spectrum is not None:
            for adapter in self.active_adapters:
                if adapter not in self.spectrum:
                    continue
                self.spectrum[adapter].requires_grad = train_spectrum
                self.spectrum_indices[adapter].requires_grad = not train_spectrum
        self.global_iter += 1

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Run the base layer plus the LoCA update(s); the result is cast back to ``x``'s dtype."""
        if self.training:
            self._update_trainable_schedule()
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        else:
            # Eval folds the update into the base weight. Training keeps it separate so
            # gradients reach the LoCA parameters.
            if not self.training and not self.merged:
                self.merge()
            elif self.training and self.merged:
                self.unmerge()

            result = self.base_layer(x, *args, **kwargs)
            if not self.merged:
                for active_adapter in self.active_adapters:
                    if active_adapter not in self.spectrum:
                        continue
                    dropout = self.loca_dropout_layer[active_adapter]
                    delta_w = self.get_delta_weight(active_adapter)
                    # get_delta_weight already returns the base weight's layout; with
                    # fan_in_fan_out that is (fan_in, fan_out), so no further transpose.
                    projection = delta_w if self.fan_in_fan_out else delta_w.T
                    result = result + dropout(x) @ projection

        return result.to(previous_dtype)

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "loca." + rep